import torch
from torch import nn
import math
from model import embedding
from pytorch3d.renderer import ray_bundle_to_ray_points
from model.nerf_simple import NeuralRadianceField
from model.raymarcher_lebesgue import RayBundle, _jiggle_within_stratas

def binary_search_inverse(y, f, n_iter=12, tol=1e-3, t_min=0, t_max=16, **kwargs):
    '''
    Numerically invert a non-decreasing function by elementwise bisection.

    Searches for ``t`` in ``[t_min, t_max]`` such that ``f(t, **kwargs)`` is close to ``y``.

    :param y: target values; the search runs independently for every element
    :param f: function to invert, called as ``f(t, **kwargs)``; assumed non-decreasing in ``t``
    :param n_iter: maximum number of bisection steps, defaults to 12 (optional)
    :param tol: return early as soon as every ``|f(t) - y|`` is at most this value, defaults to 1e-3 (optional)
    :param t_min: lower end of the search interval (scalar or tensor broadcastable to ``y``), defaults to 0 (optional)
    :param t_max: upper end of the search interval (scalar or tensor broadcastable to ``y``), defaults to 16 (optional)
    :param kwargs: extra keyword arguments forwarded to ``f``
    :return: tensor (broadcast of ``y`` with the bounds) holding the approximate inverse of ``f`` at ``y``
    '''
    lo = torch.ones_like(y) * t_min
    hi = torch.ones_like(y) * t_max
    for _ in range(n_iter):
        mid = (lo + hi) * 0.5
        value = f(mid, **kwargs)
        # every element already within tolerance: stop early
        if (value - y).abs().le(tol).all():
            return mid
        # f is non-decreasing, so f(mid) >= y means the root is in the lower half
        go_left = value >= y
        lo = torch.where(go_left, lo, mid)
        hi = torch.where(go_left, mid, hi)
    t = (lo + hi) * 0.5
    assert (t >= t_min).all(), "t is lower than t_min"
    assert (t <= t_max).all(), "t is higher than t_max"
    return t

def _shifted_cumprod(x, shift=1):
    """
    Computes `torch.cumprod(x, dim=-1)` and prepends `shift` number of
    ones and removes `shift` trailing elements to/from the last dimension
    of the result.
    """
    x_cumprod = torch.cumprod(x, dim=-2)
    x_cumprod_shift = torch.cat(
        [torch.ones_like(x_cumprod[..., :shift, :]), x_cumprod[..., :-shift, :]], dim=-2
    )
    return x_cumprod_shift


def in_cube(points, scene_size, tol=1e-2):
    """
    Check which points lie inside the axis-aligned cube ``[-scene_size, scene_size]^3``.

    A slack of ``tol`` is added to every face, so points that land on the boundary
    only up to rounding (e.g. computed ray/face intersections) are still accepted.

    :param points: tensor of shape ``(..., 3)``
    :param scene_size: half the side length of the origin-centred cube
    :param tol: slack added to each bound, defaults to 1e-2 (optional)
    :return: boolean tensor of shape ``(..., 1)``; True where every coordinate is within bounds
    """
    lower = -scene_size - tol
    upper = scene_size + tol
    coordinate_indicator = torch.logical_and(points >= lower, points <= upper)
    return torch.all(coordinate_indicator, dim=-1, keepdim=True)


class NeRFLebesgue(NeuralRadianceField):
    """
    Radiance field whose density along a ray can be integrated in closed form.

    For a ray ``o + t * d`` the density is ``sigma(t) = (w . cos(t * u + v) + w_0) ** 2``,
    where ``u = d @ k.T`` and ``v = o @ k.T + b``. Here ``k`` and ``b`` come from
    ``self.density_embedding``, and ``w`` / ``w_0`` are the weight / bias of
    ``self.density_layer`` (see ``_prepare_kwargs_from_bundle`` and ``_get_sigma``).
    ``_sigma_antiderivative`` integrates this expression analytically. That gives exact
    transparency / opacity values and lets the opacity be inverted by bisection.

    ``hparams.integration`` selects the scheme used by ``forward``. It is one of
    ``'riemann_exp_cumprod_exp'``, ``'riemann_x_cumprod_exp'``, ``'riemann_sum'``,
    ``'lebesgue_sum_v1'`` or ``'lebesgue_sum_v2'``.
    """

    def _prepare_density_layer(self, hparams, n_hidden_neurons):
        '''
        Build the density embedding, the linear density head and the (log) depth parameter.

        :param hparams: density hyper-parameters. Reads ``embedding_class`` (defaults to
            ``'NerfEmbedding'``), ``random_features_args``, ``num_features``, ``dimensions``,
            ``integration``, ``max_depth``, ``trainable_depth`` and ``scene_size``
        :param n_hidden_neurons: not used by this implementation
        '''
        density_random_features_args = hparams.get('random_features_args', {})
        # The default must be the fallback of ``hparams.get``. Passing it as the third argument
        # of ``getattr`` made a missing key raise TypeError and a misspelled class name silently
        # evaluate to the string 'NerfEmbedding'.
        embedding_class_name = hparams.get('embedding_class', None) or 'NerfEmbedding'
        embedding_class = getattr(embedding, embedding_class_name)
        self.density_embedding = embedding_class(
                num_features=hparams.num_features, dimensions=hparams.dimensions, **density_random_features_args)
        self.density_in_features = self.density_embedding.output_features

        self.integration = hparams.integration
        # stored as log1p(max_depth); every use recovers the depth with torch.expm1
        self.max_depth = nn.Parameter(torch.log1p(torch.ones(1, dtype=torch.float) * hparams.max_depth),
                requires_grad=hparams.get('trainable_depth', False))
        self.scene_size = hparams.get('scene_size', 1)
        self.density_layer = nn.Linear(self.density_in_features, 1)
        self._init_small_numbers(self.density_layer, self.density_in_features, hparams)

    def _prepare_raybundle(self, ray_bundle, t_grid_extracted, with_depth=False):
        '''
        Build a RayBundle with the same rays but new sample lengths.

        :param ray_bundle: the original raybundle; origins, directions and original_lengths are reused
        :param t_grid_extracted: the new lengths along every ray
        :param with_depth: if True, ``t_grid_extracted`` is treated as normalized and multiplied
            by the maximum depth ``expm1(self.max_depth)``, defaults to False (optional)
        :return: RayBundle with ``lengths`` replaced and ``xys`` set to None
        '''
        ray_bundle_resampled = RayBundle(
            origins=ray_bundle.origins,
            directions=ray_bundle.directions,
            lengths=t_grid_extracted * torch.expm1(self.max_depth) if with_depth else t_grid_extracted,
            original_lengths=ray_bundle.original_lengths,
            xys=None
        )
        return ray_bundle_resampled

    def _prepare_kwargs_from_bundle(self, ray_bundle):
        '''
        Flatten the rays and project them onto the embedding frequencies.

        At ``o + t * d`` the embedding argument is ``t * (d @ k.T) + (o @ k.T + b)``, where
        ``k`` / ``b`` are ``self.density_embedding.k`` / ``.b``. The two terms are returned
        as ``freq_emb`` and ``bias_emb``.

        :param ray_bundle: RayBundle with origins / directions of shape ``(..., 3)`` and lengths ``(..., P)``
        :return: dict with ``'ray_bundle'`` (rays flattened to ``(BB, ...)``), ``'freq_emb'`` and
            ``'bias_emb'``, each of shape ``(BB, 1, F)``
        '''
        directions = ray_bundle.directions.view(-1, ray_bundle.directions.shape[-1])
        origins = ray_bundle.origins.view(-1, ray_bundle.origins.shape[-1])
        lengths = ray_bundle.lengths.view(-1, ray_bundle.lengths.shape[-1])
        origins_directions = RayBundle(origins=origins, directions=directions, lengths=lengths, xys=None)

        freq_along_rays = directions @ self.density_embedding.k.T  # B F
        bias_along_rays = origins @ self.density_embedding.k.T + self.density_embedding.b[None]

        freq = freq_along_rays.unsqueeze(1)  # B 1 F
        bias = bias_along_rays.unsqueeze(1)  # B 1 F
        return {
            'ray_bundle': origins_directions,
            'freq_emb': freq,
            'bias_emb': bias,
        }

    def forward(self, ray_bundle: RayBundle, **kwargs):
        '''
        Compute per-sample color contributions along the rays with the configured scheme.

        ``ray_bundle.lengths`` and ``original_lengths`` are divided by the maximum depth
        ``expm1(self.max_depth)``. The riemann schemes treat the normalized values as ray
        positions. The lebesgue schemes treat them as fractions of the opacity range
        ``[cdf(t_min), cdf(t_max)]`` between scene entry and exit. They map those fractions
        back to ray positions by inverting the opacity.

        :param ray_bundle: RayBundle with ``lengths`` and ``original_lengths``
        :type ray_bundle: RayBundle
        :param kwargs: ignored; replaced by the output of ``_prepare_kwargs_from_bundle``
        :return: for ``'riemann*'`` schemes, a dict with ``'grid'`` (original lengths) and
            ``'avg_values'`` (colors weighted by the per-sample pdf). For ``'lebesgue*'``
            schemes, a dict with ``'grid'`` (opacity levels built from the original lengths),
            ``'avg_values'`` (colors at the resampled positions), ``'opacity'`` (opacity at the
            scene exit), ``'t_grid'`` (resampled positions) and ``'y_min'`` (opacity at the
            scene entry)
        :raises AssertionError: if ``hparams.integration`` is neither a riemann nor a lebesgue scheme
        '''
        extracted_shape = ray_bundle.directions.shape[:-1]
        extract = lambda tensor: tensor.view(*extracted_shape, -1)
        kwargs = self._prepare_kwargs_from_bundle(ray_bundle)
        depth_scale = torch.expm1(self.max_depth)

        # consider ray_bundle.lengths as ys:
        normalized_lengths = ray_bundle.lengths / depth_scale
        normalized_original_lengths = ray_bundle.original_lengths / depth_scale
        t_grid = normalized_lengths.view(-1, ray_bundle.lengths.shape[-1]).unsqueeze(-1)  # BB P 1
        t_grid_original = normalized_original_lengths.view(-1, ray_bundle.original_lengths.shape[-1]).unsqueeze(-1)  # BB P 1
        t_grid_original_extracted = extract(t_grid_original.squeeze(-1)) * depth_scale

        if self.integration == 'riemann_exp_cumprod_exp':
            t_grid_extracted = extract(t_grid.squeeze(-1)) * depth_scale
            ray_bundle_resampled = self._prepare_raybundle(ray_bundle, t_grid_extracted)
            color_values, _ = self._color(ray_bundle_resampled)
            sigma = self._get_sigma(t_grid * depth_scale, **kwargs)
            integrants = torch.exp(-sigma)
            absorption = _shifted_cumprod(integrants + 1e-10)
            pdf = absorption * (1.0 - integrants)
            pdf_values_extracted = extract(pdf.squeeze(-1)).unsqueeze(-1)

        elif self.integration == 'riemann_x_cumprod_exp':
            t_grid_extracted = extract(t_grid.squeeze(-1)) * depth_scale
            ray_bundle_resampled = self._prepare_raybundle(ray_bundle, t_grid_extracted)
            color_values, _ = self._color(ray_bundle_resampled)
            sigma = self._get_sigma(t_grid * depth_scale, **kwargs)
            integrants = torch.exp(-sigma)
            absorption = _shifted_cumprod(integrants + 1e-10)
            pdf = absorption * sigma
            pdf_values_extracted = extract(pdf.squeeze(-1)).unsqueeze(-1)

        elif self.integration == 'riemann_sum':
            # pdf() -> transparency() needs the antiderivative at the lower integration limit.
            # Without it this scheme always raised TypeError. Integrate from t = 0.
            # NOTE(review): confirm 0, rather than the first sample or the scene entry used by
            # the lebesgue schemes, is the intended lower limit.
            kwargs['antiderivative_zero'] = self._sigma_antiderivative(torch.zeros_like(t_grid[:, :1]), **kwargs)
            pdf_values = self.pdf(t_grid * depth_scale, **kwargs)
            t_grid_extracted = extract(t_grid.squeeze(-1)) * depth_scale
            pdf_values_extracted = extract(pdf_values.squeeze(-1)).unsqueeze(-1)
            ray_bundle_resampled = self._prepare_raybundle(ray_bundle, t_grid_extracted)
            color_values, _ = self._color(ray_bundle_resampled)

        if self.integration.startswith('riemann'):
            return {
                'grid': t_grid_original_extracted.unsqueeze(-1),
                'avg_values': color_values * pdf_values_extracted
            }

        assert self.integration.startswith('lebesgue'), 'check spelling in configs'
        t_min, t_max = self._get_scene_boundaries(**kwargs)
        # opacity is measured from the scene entry t_min
        kwargs['antiderivative_zero'] = self._sigma_antiderivative(t_min, **kwargs)
        y_min = self.cdf(t_min, **kwargs)
        y_max = self.cdf(t_max, **kwargs)

        # use sampled lengths as input for y, normalized to [0, 1] and then to [y_min, y_max]
        y_grid_noisy = normalized_lengths.view(-1, normalized_lengths.shape[-1]).unsqueeze(-1)  # BB P 1
        y_grid_noisy = y_min + y_grid_noisy * (y_max - y_min)
        y_grid_original = normalized_original_lengths.view(-1, normalized_original_lengths.shape[-1]).unsqueeze(-1)  # BB P 1
        y_grid_original = y_min + y_grid_original * (y_max - y_min)

        if self.integration == 'lebesgue_sum_v1':
            # positions carry the implicit gradient provided by inv_opacity
            t_grid = self.inv_opacity(y_grid_noisy, t_min=t_min, t_max=t_max, **kwargs)

        if self.integration == 'lebesgue_sum_v2':
            # positions are treated as constants (no gradient through the inversion)
            with torch.no_grad():
                t_grid = self.inv_opacity(y_grid_noisy, t_min=t_min, t_max=t_max, **kwargs)

        t_grid_extracted = extract(t_grid.squeeze(-1))
        y_grid_extracted = extract(y_grid_original.squeeze(-1))
        ray_bundle_resampled = self._prepare_raybundle(ray_bundle, t_grid_extracted)
        color_values, _ = self._color(ray_bundle_resampled)

        return {
            'grid': y_grid_extracted.unsqueeze(-1),
            'avg_values': color_values,
            'opacity': extract(y_max.squeeze(-1)),
            't_grid': t_grid_extracted.unsqueeze(-1),
            'y_min': extract(y_min.squeeze(-1))
        }

    def _get_scene_boundaries(self, ray_bundle, **kwargs):
        """
        Find the ray parameters where each ray enters and leaves ``[-scene_size, scene_size]^3``.

        Candidates are the intersections with the six face planes. Only candidates whose
        point passes ``in_cube`` are kept, and invalid candidates count as 0 for ``t_max``.
        When ``t_min`` equals ``t_max`` (e.g. the ray misses the cube), ``t_min`` is set to 0.
        Both values are clamped to be non-negative.

        :param ray_bundle: RayBundle with flattened origins / directions of shape ``(BB, 3)``
        :return: t_min, t_max, each of shape ``(BB, 1, 1)``
        """
        t_neg = -(self.scene_size + ray_bundle.origins) / ray_bundle.directions  # BB x 3
        t_pos = (self.scene_size - ray_bundle.origins) / ray_bundle.directions  # BB x 3
        t_possible_positions = torch.cat([t_neg, t_pos], dim=-1)  # BB x 6

        ray_bundle_possible_positions = self._prepare_raybundle(ray_bundle, t_possible_positions)
        possible_points = ray_bundle_to_ray_points(ray_bundle_possible_positions)  # BB x 6 x 3
        t_possible_positions = t_possible_positions.unsqueeze(-1)  # BB x 6 x 1

        # a ray can cross a face plane outside the cube, so keep only intersections inside it
        in_cube_indicator = in_cube(possible_points, self.scene_size)  # BB x 6 x 1
        t_max = torch.max(torch.where(in_cube_indicator, t_possible_positions, torch.zeros_like(t_possible_positions)), dim=-2).values
        t_min = torch.min(torch.where(in_cube_indicator, t_possible_positions, t_max.unsqueeze(-2)), dim=-2).values
        t_min = torch.where(t_min == t_max, torch.zeros_like(t_min), t_min)
        t_min = t_min.clamp_(min=0)
        t_max = t_max.clamp_(min=0)
        return t_min.unsqueeze(-1), t_max.unsqueeze(-1)

    def _get_sigma(self, t, freq_emb, bias_emb, **kwargs):
        """
        Evaluate the density ``density_layer(cos(t * freq_emb + bias_emb)) ** 2``.

        :param t: ray parameters of shape ``(BB, P, 1)`` (``torch.bmm`` requires 3-D inputs)
        :param freq_emb: ``(BB, 1, F)`` projections of the directions onto the frequencies
        :param bias_emb: ``(BB, 1, F)`` phases of the rays
        :return: non-negative densities of shape ``(BB, P, 1)``
        """
        embedding = torch.cos(torch.bmm(t, freq_emb) + bias_emb)
        linear = self.density_layer(embedding)
        return linear ** 2

    def transparency(self, t, antiderivative_zero, **kwargs):
        r"""
        Transparency ``exp(-\int_{t_0}^{t} sigma(s) ds)``, computed in closed form.

        sigma(s) = (\sum_i a_i cos(u_i * s + v_i) + b_0) ** 2

        :param t: ray parameters
        :param antiderivative_zero: ``_sigma_antiderivative`` evaluated at the lower limit ``t_0``
        :return: transparency at ``t``
        """
        int_sigma = self._sigma_antiderivative(t, **kwargs) - antiderivative_zero
        return torch.exp(-int_sigma)

    def opacity(self, t, **kwargs):
        """Opacity ``1 - transparency(t)``, the CDF of the ray-termination distribution."""
        survival = self.transparency(t, **kwargs)
        return 1. - survival

    def pdf(self, t, **kwargs):
        """Termination density ``sigma(t) * transparency(t)``; ``kwargs`` must contain ``antiderivative_zero``."""
        return self._get_sigma(t, **kwargs) * self.transparency(t, **kwargs)

    def inv_opacity(self, y, antiderivative_zero, t_min=0, t_max=16, **kwargs):
        """
        Solve ``opacity(t) = y`` for ``t`` while giving ``t`` a gradient w.r.t. the model parameters.

        The forward value comes from bisection on ``_sigma_int(t) = -log(1 - y)`` without
        autograd. A surrogate term with zero value is then added whose gradient is the
        implicit-function gradient ``dt = (exp(S) dy - dS) / sigma(t)``, with ``S = _sigma_int(t)``.

        :param y: target opacity levels
        :param antiderivative_zero: ``_sigma_antiderivative`` at the lower integration limit
        :param t_min: lower bound of the search interval, defaults to 0 (optional)
        :param t_max: upper bound of the search interval, defaults to 16 (optional)
        :return: positions ``t`` with ``opacity(t)`` close to ``y``
        """
        with torch.no_grad():
            t = binary_search_inverse((1 - y).log().neg(),
                                      self._sigma_int,
                                      self.hparams.density.get('bin_search_n_iter', 12),
                                      self.hparams.density.get('bin_search_tol', 1e-3),
                                      t_min,
                                      t_max,
                                      antiderivative_zero=antiderivative_zero,
                                      **kwargs)
            sigma = self._get_sigma(t, **kwargs)

        # a nasty trick to reattach the gradient
        sigma_int = self._sigma_int(t, antiderivative_zero, **kwargs)
        # NOTE(review): eps_gradient is read from self.hparams, the bisection settings from self.hparams.density
        gradient_surrogate = (y * (sigma_int).exp().detach() - sigma_int) / (sigma + self.hparams.get('eps_gradient', 0))
        output = t + gradient_surrogate - gradient_surrogate.detach()
        return output

    cdf = opacity
    survival = transparency
    inv_cdf = inv_opacity

    def _sigma_int(self, t, antiderivative_zero, **kwargs):
        """Integral of the density from the lower limit (encoded by ``antiderivative_zero``) to ``t``."""
        return self._sigma_antiderivative(t, **kwargs) - antiderivative_zero

    def _sigma_antiderivative(self, t, freq_emb, bias_emb, **kwargs):
        r"""
        Closed-form antiderivative of the density along the rays.

        sigma(s) = (\sum_i a_i cos(u_i * s + v_i) + b_0) ** 2, where a = density_layer.weight,
        b_0 = density_layer.bias, u = freq_emb and v = bias_emb. Expanding the square gives:

        1. a_i a_j cos(u_i s + v_i) cos(u_j s + v_j)
            1.1 i != j: 0.5 * [sin((u_i - u_j) s + (v_i - v_j)) / (u_i - u_j)
                               + sin((u_i + u_j) s + (v_i + v_j)) / (u_i + u_j)]
            1.2 i == j: 0.5 * ((u_i s + v_i) + sin(u_i s + v_i) cos(u_i s + v_i)) / u_i

        2. 2 * b_0 * a_i cos(u_i s + v_i): 2 b_0 a_i sin(u_i s + v_i) / u_i

        3. b_0 ** 2: b_0 ** 2 * s

        The result is defined up to an additive constant, so callers use differences.

        :param t: ray parameters of shape ``(B, P, 1)``
        :param freq_emb: ``(B, 1, F)``
        :param bias_emb: ``(B, 1, F)``
        :return: antiderivative values of shape ``(B, P, 1)``
        """
        embedding = t * freq_emb + bias_emb  # B points_along_ray num_features
        # \int cos s cos s ds
        # i != j
        freq_diff = freq_emb.transpose(-1, -2) - freq_emb  # B num_features num_features
        freq_sum = freq_emb.transpose(-1, -2) + freq_emb  # B num_features num_features

        diagonal_mask = torch.eye(self.density_in_features, dtype=torch.bool, device=embedding.device).unsqueeze(0).unsqueeze(0)
        # freq_diff is exactly 0 on the diagonal. Those entries are zeroed below (the i == j terms
        # are computed separately), but 0/0 there would still propagate NaN gradients into
        # t / freq_emb / bias_emb when they require grad. Dividing by 1 instead leaves the
        # forward result unchanged.
        safe_freq_diff = freq_diff.unsqueeze(1).masked_fill(diagonal_mask, 1.)

        cos_cos_i_neq_j = 0.5 * (
            torch.sin(torch.unsqueeze(embedding, -1) - torch.unsqueeze(embedding, -2)) / safe_freq_diff
            + torch.sin(torch.unsqueeze(embedding, -1) + torch.unsqueeze(embedding, -2)) / freq_sum.unsqueeze(1)
        )  # B points_along_ray num_features num_features

        cos_cos_i_neq_j.masked_fill_(diagonal_mask, 0.)
        cos_cos_i_neq_j = cos_cos_i_neq_j * (self.density_layer.weight * torch.t(self.density_layer.weight))
        cos_cos_i_neq_j = cos_cos_i_neq_j.sum(axis=-1).sum(axis=-1, keepdim=True)  # B points_along_ray 1

        # i == j
        cos_cos_i_eq_j = 0.5 * (embedding + torch.sin(embedding) * torch.cos(embedding)) / freq_emb
        cos_cos_i_eq_j = (self.density_layer.weight ** 2 * cos_cos_i_eq_j).sum(-1, keepdim=True)  # B points_along_ray 1

        cos_cos = cos_cos_i_eq_j + cos_cos_i_neq_j
        # \int cos s ds
        cos = torch.sin(embedding) / freq_emb
        cos = 2 * self.density_layer.bias * self.density_layer.weight * cos
        cos = cos.sum(axis=-1, keepdim=True)  # B points_along_ray 1
        # \int ds
        bias = self.density_layer.bias ** 2 * t
        return cos_cos + cos + bias

    def batched_forward(
        self,
        ray_bundle: RayBundle,
        split_size: int = 256,
        **kwargs,
    ):
        """
        Run ``forward`` on chunks of rays and reassemble the outputs.

        :param ray_bundle: RayBundle with origins / directions of shape ``(*spatial, 3)``,
            lengths ``(*spatial, P)`` and original_lengths ``(*spatial, P + 1)``
        :param split_size: rays per chunk, unless ``hparams.val_split_size`` overrides it
        :param kwargs: not forwarded to ``forward``
        :return: dict with ``'grid'`` of shape ``(*spatial, P + 1, -1)``, ``'avg_values'`` and
            ``'t_grid'`` of shape ``(*spatial, P, -1)``, and ``'opacity'`` / ``'y_min'`` of shape
            ``(*spatial, -1)``. Keys that ``forward`` does not return are omitted; the riemann
            schemes return only ``'grid'`` and ``'avg_values'``.
        """
        # Parse out shapes needed for tensor reshaping in this function.
        n_pts_per_ray = ray_bundle.lengths.shape[-1]
        spatial_size = ray_bundle.origins.shape[:-1]

        # Split the rays to `split_size` batches.
        tot_samples = spatial_size.numel()
        n_chunks = math.ceil(tot_samples / self.hparams.get('val_split_size', split_size))
        batches = torch.chunk(torch.arange(tot_samples), n_chunks)

        # For each batch, execute the standard forward pass.
        batch_outputs = [
            self.forward(
                RayBundle(
                    origins=ray_bundle.origins.view(-1, 3)[batch_idx],
                    directions=ray_bundle.directions.view(-1, 3)[batch_idx],
                    lengths=ray_bundle.lengths.view(-1, n_pts_per_ray)[batch_idx],
                    original_lengths=ray_bundle.original_lengths.view(-1, n_pts_per_ray + 1)[batch_idx],
                    xys=None,
                )
            ) for batch_idx in batches
        ]

        def _collate(key, *trailing_shape):
            # concatenate one output over all chunks and restore the spatial layout
            return torch.cat(
                [batch_output[key] for batch_output in batch_outputs], dim=0
            ).view(*spatial_size, *trailing_shape)

        # collate only the keys forward actually produced for the configured scheme
        available_keys = batch_outputs[0].keys()
        batch_outputs_dict = {'grid': _collate('grid', n_pts_per_ray + 1, -1)}
        for output_i in ['avg_values', 't_grid']:
            if output_i in available_keys:
                batch_outputs_dict[output_i] = _collate(output_i, n_pts_per_ray, -1)
        for output_i in ['opacity', 'y_min']:
            if output_i in available_keys:
                batch_outputs_dict[output_i] = _collate(output_i, -1)
        return batch_outputs_dict


class NeRFSpline(NeRFLebesgue):
    """
    NeRFLebesgue variant with a piecewise-linear density along each ray.

    Densities are predicted at the input samples ``ray_bundle.lengths`` by
    ``_get_density_features`` / ``_get_densities``. These are not defined in this class;
    presumably they are inherited from NeuralRadianceField. The densities are linearly
    interpolated between the samples. The integrated density is exact on every linear
    piece, and color samples are placed by inverting the resulting opacity.
    """

    def _prepare_density_layer(self, *args, **kwargs):
        """Prepare the density layers of the grandparent class, skipping NeRFLebesgue's closed-form head."""
        # super(NeRFLebesgue, self) starts the method lookup after NeRFLebesgue on purpose
        super(NeRFLebesgue, self)._prepare_density_layer(*args, **kwargs)

    def _prepare_kwargs_from_bundle(self, ray_bundle):
        """
        Predict the densities at the ray samples that define the piecewise-linear density.

        :param ray_bundle: RayBundle whose ``lengths`` are the knots of the interpolation
        :return: dict with ``ray_bundle``, ``t_vals`` (the lengths) and ``s_vals`` (densities
            at the knots, shifted up by ``hparams.epsilon_min_density``)
        """
        t = ray_bundle.lengths
        _, color_features = self._color(ray_bundle)
        density_features = self._get_density_features(ray_bundle, color_features)
        rays_densities = self._get_densities(density_features, convert_to_alphas=False).squeeze(-1)
        # NOTE: the offset presumably keeps the density away from zero; inv_opacity divides by sigma
        rays_densities = rays_densities + self.hparams.get('epsilon_min_density', 1e-3)

        return dict(ray_bundle=ray_bundle, t_vals=t, s_vals=rays_densities)

    def _prepare_raybundle(self, ray_bundle, t_samples):
        '''
        Build a RayBundle with the same rays but new sample lengths.

        :param ray_bundle: the original raybundle; origins, directions and original_lengths are reused
        :param t_samples: the new lengths along every ray
        :return: RayBundle with ``lengths`` set to ``t_samples`` and ``xys`` set to None
        '''
        ray_bundle_resampled = RayBundle(
            origins=ray_bundle.origins,
            directions=ray_bundle.directions,
            lengths=t_samples,
            original_lengths=ray_bundle.original_lengths,
            xys=None
        )
        return ray_bundle_resampled

    def forward(self, ray_bundle: RayBundle, **kwargs):
        """
        Sample colors at positions obtained by inverting the piecewise-linear opacity.

        The opacity range between the first and last knot, pulled in by ``hparams.epsilon``,
        is split into ``hparams.color.n_color_samples`` strata. Levels are drawn from the
        strata with ``_jiggle_within_stratas`` and mapped back to ray positions with
        ``inv_opacity``.

        :param ray_bundle: RayBundle whose ``lengths`` are the density knots
        :param kwargs: ignored; replaced by the output of ``_prepare_kwargs_from_bundle``
        :return: dict with ``'grid'`` (stratum boundaries in opacity), ``'avg_values'``
            (colors at the sampled positions), ``'opacity'`` / ``'y_min'`` (opacity at the upper /
            lower end of the interval) and ``'t_grid'`` (sampled positions)
        """
        eps = self.hparams.get('epsilon', 1e-3)
        t_min, _ = ray_bundle.lengths.min(-1, keepdim=True)
        t_max, _ = ray_bundle.lengths.max(-1, keepdim=True)
        # pull the interval in by eps: the last knot itself lies outside every half-open bin of _get_sigma
        t_lo = t_min + eps
        t_hi = t_max - eps
        kwargs = self._prepare_kwargs_from_bundle(ray_bundle)
        kwargs['antiderivative_zero'] = self._sigma_antiderivative(t_lo, **kwargs)
        y_min = self.opacity(t_lo, **kwargs)
        y_max = self.opacity(t_hi, **kwargs)
        # get y_grid and samples, this part can be integrated into ray sampler
        y_grid = torch.linspace(1e-8, 1 - 1e-8, self.hparams.color.get('n_color_samples', 8) + 1, device=y_max.device)
        y_grid = y_min + (y_max - y_min) * y_grid
        y_samples = _jiggle_within_stratas(y_grid, single_sample_per_bin=True)
        t_samples = self.inv_opacity(y_samples, t_min=t_lo, t_max=t_hi, **kwargs)

        ray_bundle_resampled = self._prepare_raybundle(ray_bundle, t_samples)
        color_values, _ = self._color(ray_bundle_resampled)

        return {
            'grid': y_grid.unsqueeze(-1),
            'avg_values': color_values,
            'opacity': y_max,
            'y_min': y_min,
            't_grid': t_samples.unsqueeze(-1),
        }

    def _get_sigma(self, t, **kwargs):
        """
        Evaluate the piecewise-linear density at ``t``.

        Each query is matched to the half-open bin ``[t_l, t_r)`` of consecutive knots that
        contains it. Queries outside every bin, including the last knot itself, get density 0.

        :param t: query positions
        :param kwargs: must contain ``t_vals`` (knots) and ``s_vals`` (densities at the knots)
        :return: densities with the shape of ``t``
        """
        t_l = kwargs['t_vals'][..., None, :-1]
        t_r = kwargs['t_vals'][..., None, 1:]
        s_l = kwargs['s_vals'][..., None, :-1]
        s_r = kwargs['s_vals'][..., None, 1:]
        t = t.unsqueeze(-1)

        slope = (s_r - s_l) / (t_r - t_l)
        offset = s_l - slope * t_l

        bin_indicator = torch.logical_and((t < t_r), (t >= t_l))
        slope = slope.masked_fill(~bin_indicator, 0).sum(-1)
        offset = offset.masked_fill(~bin_indicator, 0).sum(-1)
        return slope * t.squeeze(-1) + offset

    def _sigma_antiderivative(self, t, **kwargs):
        """
        Integral of the piecewise-linear density from the first knot up to ``t``.

        Bins ending at or before ``t`` contribute their full trapezoid area. The bin that
        contains ``t`` contributes the part from its left knot to ``t``.

        :param t: query positions
        :param kwargs: must contain ``t_vals`` (knots) and ``s_vals`` (densities at the knots)
        :return: integral values with the shape of ``t``
        """
        t_l = kwargs['t_vals'][..., None, :-1]
        t_r = kwargs['t_vals'][..., None, 1:]
        s_l = kwargs['s_vals'][..., None, :-1]
        s_r = kwargs['s_vals'][..., None, 1:]
        t = t.unsqueeze(-1)

        slope = (s_r - s_l) / (t_r - t_l)
        offset = s_l - slope * t_l
        # ``<=`` so that a bin ending exactly at t is counted in full. With ``<``, such a bin was
        # neither full nor current, so its area was dropped at every knot and at the last knot.
        full_bin_indicator = t_r <= t
        curr_bin_indicator = torch.logical_and(
            t < t_r, t >= t_l)
        # integrate from t_l till t_r bins on the left
        full_part = (0.5 * slope * (t_r ** 2 - t_l ** 2) + offset * (t_r - t_l))
        full_part = full_part.masked_fill(~full_bin_indicator, 0).sum(-1)
        # integrate current bin till t
        slope = slope.masked_fill(~curr_bin_indicator, 0).sum(-1)
        offset = offset.masked_fill(~curr_bin_indicator, 0).sum(-1)
        t_l = t_l.masked_fill(~curr_bin_indicator, 0).sum(-1)
        t = t.squeeze(-1)
        curr_part = (0.5 * slope * (t ** 2 - t_l ** 2) + offset * (t - t_l))
        return full_part + curr_part

    def batched_forward(
        self,
        ray_bundle: RayBundle,
        split_size: int = 256,
        **kwargs,
    ):
        """
        Run ``forward`` on chunks of rays and reassemble the outputs.

        :param ray_bundle: RayBundle whose origins have shape ``(*batch_shape, 3)``
        :param split_size: rays per chunk, unless ``hparams.val_split_size`` overrides it
        :param kwargs: not forwarded to ``forward``
        :return: dict with ``'grid'`` and ``'t_grid'`` of shape ``(*batch_shape, -1, 1)``,
            ``'avg_values'`` of shape ``(*batch_shape, -1, 3)``, and ``'opacity'`` / ``'y_min'``
            of shape ``(*batch_shape, 1)``
        """
        batch_shape = ray_bundle.origins.shape[:-1]
        batch_len = batch_shape.numel()

        # Split the rays to `split_size` batches.
        n_chunks = math.ceil(batch_len / self.hparams.get('val_split_size', split_size))
        batches = torch.chunk(torch.arange(batch_len), n_chunks)

        # For each batch, execute the standard forward pass.
        batch_outputs = [
            self.forward(
                RayBundle(
                    origins=ray_bundle.origins.view(batch_len, -1)[batch_idx],
                    directions=ray_bundle.directions.view(batch_len, -1)[batch_idx],
                    lengths=ray_bundle.lengths.view(batch_len, -1)[batch_idx],
                    original_lengths=ray_bundle.original_lengths.view(batch_len, -1)[batch_idx],
                    xys=None,
                )
            ) for batch_idx in batches
        ]

        def _collate(key, *trailing_shape):
            # concatenate one output over all chunks and restore the batch layout
            return torch.cat(
                [batch_output[key] for batch_output in batch_outputs], dim=0
            ).view(*batch_shape, *trailing_shape)

        return {
            'grid': _collate('grid', -1, 1),
            'avg_values': _collate('avg_values', -1, 3),
            't_grid': _collate('t_grid', -1, 1),
            'opacity': _collate('opacity', 1),
            # forward also returns y_min; collate it like NeRFLebesgue.batched_forward does
            'y_min': _collate('y_min', 1),
        }